package io.gitee.h25094152.chatgpt;


import cn.hutool.http.HttpRequest;
import cn.hutool.json.JSONUtil;

import io.gitee.h25094152.chatgpt.entity.ChatGptParam;
import io.gitee.h25094152.chatgpt.entity.ChatGptResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;

/**
 * Utility for calling the OpenAI completion APIs: the legacy text endpoint
 * ({@code /v1/completions}) and the chat endpoint ({@code /v1/chat/completions}).
 *
 * <p>Usually obtained through {@link #getInstance()} and configured with the fluent setters.
 * The setters are not synchronized, so configure the instance before sharing it between threads.
 * The proxy settings are stored in static fields and are therefore shared by every instance.
 */
@Slf4j
public class ChatGptUtil {

    /** Legacy text completion endpoint, used by {@link #chatGPTRequest(List, String)}. */
    private static final String COMPLETIONS_URL = "https://api.openai.com/v1/completions";
    /** Chat completion endpoint, used by {@link #chatGPTRequestNew(List, String)}. */
    private static final String CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
    /** {@code max_tokens} value sent with legacy completion requests. */
    private static final int MAX_TOKENS = 2048;
    /** Connection and read timeout applied to every request, in milliseconds. */
    private static final int TIMEOUT_MILLIS = 60000;

    // Proxy settings are static (shared by all instances); kept that way for compatibility.
    private static String proxyIP = "";
    private static Integer proxyProt = 0;

    // volatile is required for the double-checked locking in getInstance() to publish safely.
    private static volatile ChatGptUtil chatGptUtil;

    /** Model used by the legacy completion endpoint. */
    private String model = "text-davinci-003";
    /** Model used by the chat completion endpoint. */
    private String chatModel = "gpt-3.5-turbo-0301";
    /** OpenAI API key, sent as a bearer token. */
    private String apiKey = "";

    /**
     * Returns the shared instance, creating it lazily on first use.
     *
     * @return the shared {@code ChatGptUtil} instance
     */
    public static ChatGptUtil getInstance() {
        ChatGptUtil instance = chatGptUtil;
        if (instance == null) {
            synchronized (ChatGptUtil.class) {
                instance = chatGptUtil;
                if (instance == null) {
                    instance = new ChatGptUtil();
                    chatGptUtil = instance;
                }
            }
        }
        return instance;
    }

    /**
     * Discards the shared instance so that the next {@link #getInstance()} call creates a new one.
     * The (misspelled) name is kept for backward compatibility. Static proxy settings are not reset.
     */
    public void destory() {
        chatGptUtil = null;
    }

    /**
     * Sets the OpenAI API key used for all subsequent requests.
     *
     * @param API_KEY the API key
     * @return this instance, for chaining
     */
    public ChatGptUtil setApiKey(String API_KEY) {
        this.apiKey = API_KEY;
        return this;
    }

    /**
     * Sets the model used by the legacy completion endpoint ({@link #chatGPTRequest(List, String)}).
     *
     * @param model the model name
     * @return this instance, for chaining
     */
    public ChatGptUtil setModel(String model) {
        this.model = model;
        return this;
    }

    /**
     * Sets the model used by the chat completion endpoint ({@link #chatGPTRequestNew(List, String)}).
     * Defaults to {@code gpt-3.5-turbo-0301}.
     *
     * @param chatModel the chat model name
     * @return this instance, for chaining
     */
    public ChatGptUtil setChatModel(String chatModel) {
        this.chatModel = chatModel;
        return this;
    }

    /**
     * Configures an HTTP proxy for all subsequent requests. The values are static, so they apply to
     * every instance. The proxy is only used when the host is non-empty and the port is positive.
     *
     * @param proxyIP   proxy host
     * @param proxyProt proxy port
     * @return this instance, for chaining
     */
    public ChatGptUtil setProxy(String proxyIP, Integer proxyProt) {
        ChatGptUtil.proxyIP = proxyIP;
        ChatGptUtil.proxyProt = proxyProt;
        return this;
    }

    /**
     * Calls the legacy completion endpoint through the given proxy.
     *
     * <p>Note: the proxy is stored in the static proxy settings and stays in effect for later calls,
     * exactly as if {@link #setProxy(String, Integer)} had been called.
     *
     * @param historyAnswer previous question/answer pairs, may be null or empty
     * @param newQuestion   the new question, must not be empty
     * @param proxyIP       proxy host
     * @param proxyProt     proxy port
     * @return the text of the first completion choice
     */
    public String chatGPTRequest(List<ChatGptParam> historyAnswer, String newQuestion, String proxyIP, Integer proxyProt) {
        ChatGptUtil.proxyIP = proxyIP;
        ChatGptUtil.proxyProt = proxyProt;

        return this.chatGPTRequest(historyAnswer, newQuestion);
    }

    /**
     * Calls the legacy completion endpoint, folding the conversation history into a single prompt.
     *
     * <p>Example conversation:
     * <pre>
     * Q: hello
     * A: hello!
     * Q: Who are you?
     * </pre>
     * is sent as the prompt {@code (You:hello\n)hello!(You:Who are you?)}.
     *
     * @param historyAnswer previous question/answer pairs, may be null or empty; empty parts are skipped
     * @param newQuestion   the new question, must not be empty
     * @return the text of the first completion choice
     * @throws IllegalArgumentException if {@code newQuestion} is empty
     * @throws IllegalStateException    if no API key is set or the response contains no choices
     */
    public String chatGPTRequest(List<ChatGptParam> historyAnswer, String newQuestion) {
        Assert.hasLength(newQuestion, "Question cannot be empty!");

        StringBuilder prompt = new StringBuilder();
        if (!CollectionUtils.isEmpty(historyAnswer)) {
            for (ChatGptParam item : historyAnswer) {
                if (item == null) {
                    continue;
                }
                if (StringUtils.hasLength(item.getQuestion())) {
                    prompt.append("(You:").append(item.getQuestion()).append("\n)");
                }
                // Append the answer itself (not the question again) so the model sees its own reply.
                if (StringUtils.hasLength(item.getAnswer())) {
                    prompt.append(item.getAnswer());
                }
            }
        }
        prompt.append("(You:").append(newQuestion).append(")");

        Map<String, Object> params = new HashMap<>();
        params.put("prompt", prompt.toString());
        params.put("max_tokens", MAX_TOKENS);
        params.put("model", model);

        return parseResult(post(COMPLETIONS_URL, params)).getChoices().get(0).getText();
    }

    /**
     * Calls the chat completion endpoint, sending the history as alternating user/assistant messages.
     * Uses the same API key, proxy and timeouts as {@link #chatGPTRequest(List, String)}.
     *
     * @param historyAnswer previous question/answer pairs, may be null or empty; empty parts are skipped
     * @param newQuestion   the new question, must not be empty
     * @return the content of the first choice's message
     * @throws IllegalArgumentException if {@code newQuestion} is empty
     * @throws IllegalStateException    if no API key is set or the response contains no choices
     */
    public String chatGPTRequestNew(List<ChatGptParam> historyAnswer, String newQuestion) {
        Assert.hasLength(newQuestion, "Question cannot be empty!");

        List<Map<String, Object>> messages = new ArrayList<>();
        if (!CollectionUtils.isEmpty(historyAnswer)) {
            for (ChatGptParam item : historyAnswer) {
                if (item == null) {
                    continue;
                }
                // Skip empty turns so no message is sent with null/empty content.
                if (StringUtils.hasLength(item.getQuestion())) {
                    messages.add(chatMessage("user", item.getQuestion()));
                }
                if (StringUtils.hasLength(item.getAnswer())) {
                    messages.add(chatMessage("assistant", item.getAnswer()));
                }
            }
        }
        messages.add(chatMessage("user", newQuestion));

        Map<String, Object> params = new HashMap<>();
        params.put("messages", messages);
        params.put("model", chatModel);

        return parseResult(post(CHAT_COMPLETIONS_URL, params)).getChoices().get(0).getMessage().getContent();
    }

    /** Builds a single chat message object with the given {@code role} and {@code content}. */
    private static Map<String, Object> chatMessage(String role, String content) {
        Map<String, Object> message = new HashMap<>();
        message.put("role", role);
        message.put("content", content);
        return message;
    }

    /**
     * POSTs {@code params} as JSON to {@code url} using the configured API key, timeouts and optional
     * proxy, and returns the raw response body.
     */
    private String post(String url, Map<String, Object> params) {
        Assert.state(StringUtils.hasLength(apiKey), "API key has not been set!");

        HttpRequest request = HttpRequest.post(url)
                .auth("Bearer " + apiKey)
                .contentType("application/json")
                .body(JSONUtil.toJsonStr(params))
                .setReadTimeout(TIMEOUT_MILLIS)
                .setConnectionTimeout(TIMEOUT_MILLIS);

        if (StringUtils.hasLength(proxyIP) && proxyProt != null && proxyProt > 0) {
            request.setHttpProxy(proxyIP, proxyProt);
        }
        return request.execute().body();
    }

    /**
     * Parses a response body and ensures it contains at least one choice. Error responses have no
     * {@code choices}, so without this check callers would see an opaque NPE or index error.
     */
    private static ChatGptResult parseResult(String body) {
        ChatGptResult result = StringUtils.hasLength(body) ? JSONUtil.toBean(body, ChatGptResult.class) : null;
        if (result == null || CollectionUtils.isEmpty(result.getChoices())) {
            throw new IllegalStateException("OpenAI response contains no choices: " + body);
        }
        return result;
    }
}
